import torch
from torch import nn

class DenoiseCNN(nn.Module):
    """Convolutional denoiser conditioned on an extra single-channel map ``eta``.

    Pipeline (as wired in ``__init__``/``forward``):

    * ``encoder``: two stages of 3x3 conv (padding 1) + ReLU + 2x2 max-pool,
      1 -> 32 -> 64 channels, halving height and width twice.
    * ``fc1``: fully connected layer over the flattened 64x7x7 feature map.
    * ``eta`` is appended as a 65th channel.
    * ``fc2``: fully connected layer over the flattened 65x7x7 map.
    * ``decoder``: two 2x2, stride-2 transposed convs, 65 -> 32 -> 1 channels,
      doubling height and width twice, followed by ``tanh``.

    The flatten/reshape steps hard-code a 7x7 bottleneck, so inputs must be
    single-channel images whose spatial size pools down to 7x7 (e.g. 28x28).
    The output is a single-channel 28x28 map with values in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Sequential(nn.Linear(64 * 7 * 7, 64 * 7 * 7), nn.ReLU())

        # 65 = 64 encoder channels + 1 channel carrying eta.
        self.fc2 = nn.Sequential(nn.Linear(65 * 7 * 7, 65 * 7 * 7), nn.ReLU())

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(65, 32, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2),
            nn.Tanh(),
        )

    def forward(self, x, eta):
        """Denoise ``x`` conditioned on ``eta``.

        Args:
            x: Image batch of shape ``(B, 1, H, W)`` where ``H`` and ``W``
                reduce to 7 after two 2x2 poolings (e.g. 28x28).
            eta: Conditioning map with a 7x7 spatial size, either per sample
                as ``(B, 7, 7)`` / ``(B, 1, 7, 7)``, or a single ``(7, 7)``
                map shared by every sample in the batch.

        Returns:
            Tensor of shape ``(B, 1, 28, 28)`` with values in ``[-1, 1]``.

        Raises:
            ValueError: If ``eta`` is not 2-, 3- or 4-dimensional, or is 4-D
                with more than one channel.
        """
        x = self.encoder(x)
        x = x.view(-1, 64 * 7 * 7)

        # NOTE(review): fc1 is applied twice with the same weights. Kept as in
        # the original; confirm this is intentional rather than a typo.
        x = self.fc1(x)
        x = self.fc1(x)

        x = x.view(-1, 64, 7, 7)
        x = self._append_eta(x, eta)

        x = x.view(-1, 65 * 7 * 7)
        # NOTE(review): fc2 is likewise applied twice with shared weights.
        x = self.fc2(x)
        x = self.fc2(x)

        x = x.view(-1, 65, 7, 7)
        return self.decoder(x)

    @staticmethod
    def _append_eta(x, eta):
        """Concatenate ``eta`` to ``x`` (``(B, 64, 7, 7)``) as one extra channel."""
        if eta.dim() == 2:
            # One (7, 7) map for the whole batch: add batch and channel axes,
            # then broadcast across the batch so the concat is along channels.
            eta = eta.unsqueeze(0).unsqueeze(0).expand(x.size(0), -1, -1, -1)
        elif eta.dim() == 3:
            # Per-sample (B, 7, 7) maps: add the channel axis.
            eta = eta.unsqueeze(1)
        elif eta.dim() != 4 or eta.size(1) != 1:
            # Anything else would not yield exactly 65 channels and would make
            # the later view(-1, 65 * 7 * 7) fail or silently mix samples.
            raise ValueError(
                "eta must have shape (7, 7), (B, 7, 7) or (B, 1, 7, 7), "
                f"got {tuple(eta.shape)}"
            )
        # Match x's dtype/device so the concatenation and fc2 see one dtype.
        return torch.cat((x, eta.to(x)), dim=1)